#!/bin/env python2.7
# encoding: utf-8
'''
Created on Oct 29, 2013

@author: mkiyer
'''
import sys
import os
import argparse
import logging
import shutil
import re
import itertools
from multiprocessing import Process, Queue

# set matplotlib backend
import matplotlib
matplotlib.use('Agg')

# third-party packages
import numpy as np
from jinja2 import Environment, PackageLoader
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Polygon

# local imports
from ssea.lib.kernel import ssea_kernel, RandomState
from ssea.lib.base import Result, SampleSet
from ssea.lib.config import Config
from ssea.lib.countdata import BigCountMatrix

# Locate the static web assets (the 'web' directory inside the installed
# ssea package); report() copies this tree into the output directory when
# HTML output is requested.
import ssea
SRC_WEB_PATH = os.path.join(ssea.__path__[0], 'web')

# Name of the default output sub-directory created inside the input
# directory when no explicit output directory is given.
REPORT_DIR = 'report'

# Jinja2 environment that loads templates from the "templates" directory of
# the ssea package, with the loop-controls extension enabled.
env = Environment(loader=PackageLoader("ssea", "templates"),
                  extensions=["jinja2.ext.loopcontrols"])

# Single module-level matplotlib figure (figure number 0) that
# create_report() passes to the plotting functions, which clear and redraw it
# instead of allocating a new figure for every plot.
global_fig = plt.figure(0)

class SSEAData:
    """Plain attribute container for the values produced by ssea_rerun()."""

class ReportConfig(object):
    """Settings for a report run, populated from the command line.

    The attributes mirror the command-line options: number of worker
    processes, input / matrix / output directories, flags selecting which
    artifacts to create, and optional row / column metadata files.
    """

    def __init__(self):
        self.num_processes = 1
        self.input_dir = None
        self.matrix_dir = None
        self.output_dir = None
        self.create_html = False
        self.create_pdf = False
        self.create_png = False
        self.create_expr_png = False
        self.create_expr_pdf = False
        self.col_metadata_file = None
        self.row_metadata_file = None

    def get_argument_parser(self, parser=None):
        """Return an argument parser carrying all report options.

        parser -- existing ``argparse.ArgumentParser`` to extend; a new one
                  is created when None.

        Option defaults are taken from the current attribute values of this
        instance.
        """
        if parser is None:
            parser = argparse.ArgumentParser()
        parser.add_argument("-v", "--verbose", dest="verbose",
                            action="store_true", default=False,
                            help="set verbosity level [default: %(default)s]")
        parser.add_argument('-p', '--num-processes', dest='num_processes',
                            type=int, default=self.num_processes,
                            help='Number of processor cores available '
                            '[default=%(default)s]')
        parser.add_argument('-o', '--output-dir', dest='output_dir',
                            help='Output directory [default will add '
                            '"report" directory to input folder]')
        parser.add_argument('input_dir')
        parser.add_argument('matrix_dir')
        grp = parser.add_argument_group("SSEA Report Options")
        grp.add_argument('--html', dest="create_html",
                         action="store_true", default=self.create_html,
                         help='Create detailed html reports')
        grp.add_argument('--pdf', dest="create_pdf",
                         action="store_true", default=self.create_pdf,
                         help='Create PDF plots')
        grp.add_argument('--png', dest="create_png",
                         action="store_true", default=self.create_png,
                         help='Create PNG plots')
        grp.add_argument('--exprpdf', dest="create_expr_pdf",
                         action="store_true", default=self.create_expr_pdf,
                         help='Create expression PDF plots')
        grp.add_argument('--exprpng', dest="create_expr_png",
                         action="store_true", default=self.create_expr_png,
                         help='Create expression PNG plots')
        grp.add_argument('--colmeta', dest='col_metadata_file',
                         help='file containing metadata corresponding to each '
                         'column of the weight matrix file')
        grp.add_argument('--rowmeta', dest='row_metadata_file',
                         help='file containing metadata corresponding to each '
                         'row of the weight matrix file')
        return parser

    def log(self, log_func=logging.info):
        """Write every setting through ``log_func``, one call per line."""
        log_func("Parameters")
        log_func("----------------------------------")
        log_func("num processes:               %d" % (self.num_processes))
        log_func("input directory:             %s" % (self.input_dir))
        log_func("matrix directory:            %s" % (self.matrix_dir))
        log_func("output directory:            %s" % (self.output_dir))
        log_func("create html report:          %s" % (self.create_html))
        log_func("create pdf plots:            %s" % (self.create_pdf))
        log_func("create png plots:            %s" % (self.create_png))
        log_func("create expression pdf plots: %s" % (self.create_expr_pdf))
        log_func("create expression png plots: %s" % (self.create_expr_png))
        log_func("col metadata file:           %s" % (self.col_metadata_file))
        log_func("row metadata file:           %s" % (self.row_metadata_file))
        log_func("----------------------------------")

    @staticmethod
    def parse_args(parser=None, args=None):
        """Parse and validate command-line arguments.

        parser -- optional ``argparse.ArgumentParser`` to extend.
        args   -- optional list of argument strings; when None, argparse
                  reads them from ``sys.argv``.

        Returns a populated ReportConfig.  Also configures the root logger
        (DEBUG when --verbose is given, INFO otherwise).  Any invalid
        argument is reported through ``parser.error()``, which exits the
        process (raises SystemExit).
        """
        config = ReportConfig()
        parser = config.get_argument_parser(parser)
        options = parser.parse_args(args)
        # setup logging
        level = logging.DEBUG if options.verbose else logging.INFO
        logging.basicConfig(level=level,
                            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        # copy plain flags
        config.create_html = options.create_html
        config.create_pdf = options.create_pdf
        config.create_png = options.create_png
        config.create_expr_pdf = options.create_expr_pdf
        config.create_expr_png = options.create_expr_png
        # With fewer than one worker nothing would ever consume the results
        # queued by the producer process, so reject the value up front.
        if options.num_processes < 1:
            parser.error("number of processes (%d) must be at least 1" %
                         (options.num_processes))
        config.num_processes = options.num_processes
        # check input directory
        if not os.path.exists(options.input_dir):
            parser.error("input directory '%s' not found" % (options.input_dir))
        config.input_dir = options.input_dir
        # check matrix directory
        if not os.path.exists(options.matrix_dir):
            parser.error("matrix directory '%s' not found" % (options.matrix_dir))
        config.matrix_dir = options.matrix_dir
        # setup output directory; refuse to write into an existing one
        if options.output_dir is None:
            config.output_dir = os.path.join(config.input_dir, REPORT_DIR)
        else:
            config.output_dir = options.output_dir
        if os.path.exists(config.output_dir):
            parser.error("Report directory '%s' already exists" %
                         (config.output_dir))
        # optional metadata files are stored as absolute paths
        for attr, label in (('col_metadata_file', 'Column'),
                            ('row_metadata_file', 'Row')):
            path = getattr(options, attr)
            if path is not None:
                if not os.path.exists(path):
                    parser.error("%s metadata file '%s' not found" % (label, path))
                path = os.path.abspath(path)
            setattr(config, attr, path)
        return config

def parse_metadata(filename, names):
    """Read a tab-delimited metadata table and index it by position in names.

    The first line of the file is the header.  In every following line the
    first column is the record name that is matched against ``names``.

    filename -- path of the tab-delimited file
    names    -- sequence of names (e.g. matrix row or column names)

    Returns a dict mapping each index ``i`` for which ``names[i]`` occurs in
    the file to a dict of ``header field -> value`` for that record.  Names
    missing from the file get no entry.  An empty file yields an empty dict.
    """
    records = {}
    with open(filename) as f:
        header_line = next(f, None)
        if header_line is None:
            return {}
        # Strip only the line terminator: stripping all whitespace would
        # also drop trailing tabs and thereby lose empty trailing columns.
        header_fields = header_line.rstrip('\r\n').split('\t')
        for line in f:
            if not line.strip():
                # ignore blank lines
                continue
            fields = line.rstrip('\r\n').split('\t')
            records[fields[0]] = dict(zip(header_fields, fields))
    return dict((i, records[name]) for i, name in enumerate(names)
                if name in records)

def get_library_sizes(bm):
    """Return the total count of every library, in millions.

    bm -- matrix object exposing ``shape`` and ``counts_t``; ``counts_t[j,:]``
          is read once per column ``j`` in ``range(bm.shape[1])``.

    Non-finite entries (NaN, +/-inf) are left out of each sum.  The result is
    a float array of length ``bm.shape[1]`` holding ``sum / 1e6``.
    """
    logging.debug('Getting total counts per library')
    num_libs = bm.shape[1]
    # builtin float instead of the np.float alias, which newer NumPy
    # releases no longer provide
    lib_sizes = np.empty(num_libs, dtype=float)
    # one library at a time so only a single slice is materialized at once
    for j in range(num_libs):
        a = bm.counts_t[j,:]
        lib_sizes[j] = a[np.isfinite(a)].sum()
    lib_sizes /= 1.0e6
    return lib_sizes

def ssea_rerun(t_id, rand_seed, bm, ssea_config, sample_set):
    """Re-run the SSEA kernel for one matrix row and collect plotting data.

    t_id        -- row index into ``bm.counts``
    rand_seed   -- seed used to construct the RandomState passed to the kernel
    bm          -- matrix object exposing ``colnames``, ``counts`` and
                   ``size_factors``
    ssea_config -- object providing ``noise_loc``, ``noise_scale``,
                   ``weight_miss``, ``weight_hit`` and ``weight_param``
    sample_set  -- object whose ``get_array(colnames)`` returns the
                   membership array (negative entries are dropped here)

    Returns an SSEAData instance.  ``membership``, ``sample_inds``,
    ``raw_weights``, ``weights_miss`` and ``weights_hit`` are reordered by the
    ``ranks`` array returned by the kernel; ``counts`` keeps the filtered,
    unranked order.
    """
    # get membership array for sample set
    membership = sample_set.get_array(bm.colnames)
    valid_samples = (membership >= 0)
    # read from memmap (builtin float replaces the np.float alias, which
    # newer NumPy releases no longer provide)
    counts = np.array(bm.counts[t_id,:], dtype=float)
    # keep samples that belong to the sample set and have a finite count
    valid_inds = np.logical_and(valid_samples, np.isfinite(counts))
    # subset counts, size_factors, and membership array
    counts = counts[valid_inds]
    size_factors = bm.size_factors[valid_inds]
    membership = membership[valid_inds]
    # reproduce previous run
    rng = RandomState(rand_seed)
    (ranks, norm_counts, norm_counts_miss, norm_counts_hit,
     es_val, es_rank, es_run) = \
        ssea_kernel(counts, size_factors, membership, rng,
                    resample_counts=False,
                    permute_samples=False,
                    add_noise=True,
                    noise_loc=ssea_config.noise_loc,
                    noise_scale=ssea_config.noise_scale,
                    method_miss=ssea_config.weight_miss,
                    method_hit=ssea_config.weight_hit,
                    method_param=ssea_config.weight_param)
    # make object for plotting
    m = membership[ranks]
    d = SSEAData()
    d.counts = counts
    d.membership = m
    d.es = es_val
    d.es_rank = es_rank
    d.es_run = es_run
    # positions (in ranked order) of samples with positive membership
    d.hit_indexes = (m > 0).nonzero()[0]
    d.ranks = ranks
    # map ranked positions back to column indexes of the full matrix
    d.sample_inds = valid_inds.nonzero()[0][ranks]
    d.raw_weights = norm_counts[ranks]
    d.weights_miss = norm_counts_miss[ranks]
    d.weights_hit = norm_counts_hit[ranks]
    return d

def plot_enrichment(result, sseadata, title, fig=None):
    """Draw the three-panel enrichment plot and return the figure.

    result   -- object providing ``resample_es_ranks``, ``resample_es_vals``,
                ``null_es_ranks`` and ``null_es_vals`` (scatter points)
    sseadata -- object providing ``es_run``, ``es_rank``, ``membership``,
                ``weights_miss`` and ``weights_hit``
    title    -- title of the top panel
    fig      -- figure to draw into; it is cleared first.  A new
                ``plt.Figure`` is created when None.

    Panels, top to bottom: running enrichment score with resampled / null ES
    points, sample-set membership bars, and the miss / hit weights.
    """
    if fig is None:
        fig = plt.Figure()
    else:
        fig.clf()
    gs = gridspec.GridSpec(3, 1, height_ratios=[2,1,1])
    num_samples = len(sseadata.es_run)
    # running enrichment score
    ax0 = fig.add_subplot(gs[0])
    p1 = ax0.scatter(result.resample_es_ranks, result.resample_es_vals,
                     c='r', s=25.0, alpha=0.3, edgecolors='none')
    p2 = ax0.scatter(result.null_es_ranks, result.null_es_vals,
                     c='b', s=25.0, alpha=0.3, edgecolors='none')
    ax0.plot(np.arange(num_samples), sseadata.es_run, lw=2, color='k',
             label='Enrichment profile')
    ax0.axhline(y=0, color='gray')
    ax0.axvline(x=sseadata.es_rank, lw=1, linestyle='--', color='black')
    ax0.set_xlim((0, num_samples))
    ax0.grid(True)
    ax0.set_xticklabels([])
    ax0.set_ylabel('Enrichment score (ES)')
    ax0.set_title(title)
    # pass the location by keyword; a positional location after the
    # handles and labels is not accepted by newer matplotlib releases
    legend = ax0.legend((p1,p2), ('Resampled ES', 'Null ES'),
                        loc='upper right',
                        numpoints=1, scatterpoints=1,
                        prop={'size': 'xx-small'})
    # The frame is matplotlib.patches.Rectangle instance surrounding the legend.
    frame = legend.get_frame()
    frame.set_linewidth(0)
    # membership in sample set
    ax1 = fig.add_subplot(gs[1])
    ax1.bar(np.arange(len(sseadata.membership)), sseadata.membership,
            width=1, linewidth=0, color='black')
    ax1.set_xlim((0, len(sseadata.membership)))
    ax1.set_ylim((0, 1))
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])
    ax1.set_ylabel('Set')
    # weights
    ax2 = fig.add_subplot(gs[2])
    # TODO: if hit and miss weights differ add a legend
    ax2.plot(sseadata.weights_miss, color='blue')
    ax2.plot(sseadata.weights_hit, color='red')
    ax2.set_xlim((0, num_samples))
    ax2.set_xlabel('Samples')
    ax2.set_ylabel('Weights')
    # TODO: this appears to cause errors in a platform-dependent manner
    #fig.tight_layout()
    return fig

def plot_expression(t_id, transcript_length, lib_sizes, sample_set, bm, title, fig):
    """Draw a hits-vs-misses FPKM boxplot for one matrix row.

    t_id              -- row index into ``bm.counts``
    transcript_length -- transcript length; divided by 1000 for normalization
    lib_sizes         -- per-column library sizes (array indexable by a
                         boolean mask over all columns)
    sample_set        -- object providing ``get_array(colnames)`` and ``name``
    bm                -- matrix object exposing ``colnames`` and ``counts``
    title             -- plot title
    fig               -- figure to draw into; it is cleared first.  A new
                         ``plt.Figure`` is created when None.

    Samples with negative membership or a non-finite count are excluded.
    Returns the figure.
    """
    # get membership array for sample set
    membership = sample_set.get_array(bm.colnames)
    valid_samples = (membership >= 0)
    # read from memmap (builtin float replaces the np.float alias, which
    # newer NumPy releases no longer provide)
    counts = np.array(bm.counts[t_id,:], dtype=float)
    # remove 'nan' values
    valid_inds = np.logical_and(valid_samples, np.isfinite(counts))
    # subset counts, library sizes, and membership array
    counts = counts[valid_inds]
    lib_sizes = lib_sizes[valid_inds]
    membership = membership[valid_inds]
    # normalize
    fpkm = (counts / lib_sizes) / (transcript_length / 1000.0)
    # subset into hits/misses
    data = [fpkm[membership == 1], fpkm[membership == 0]]
    # plot
    if fig is None:
        fig = plt.Figure()
    else:
        fig.clf()
    gs = gridspec.GridSpec(1, 1, height_ratios=[1])
    # only one subplot
    ax0 = fig.add_subplot(gs[0])
    # draw boxplot
    bp = ax0.boxplot(data, notch=0, sym='+', vert=1, whis=1.5)
    # change colors
    plt.setp(bp['boxes'], color='black')
    plt.setp(bp['whiskers'], color='black')
    plt.setp(bp['fliers'], color='red', marker='+')
    # Add a horizontal grid to the plot, but make it very light in color
    # so we can use it for reading data values but not be distracting
    ax0.yaxis.grid(True, linestyle='-', which='major', color='lightgrey',
                   alpha=0.5)
    # Hide these grid behind plot objects
    ax0.set_axisbelow(True)
    ax0.set_title(title)
    ax0.set_xlabel('Sample Set: %s' % (sample_set.name))
    ax0.set_ylabel('FPKM')
    # Now fill the boxes with desired colors: red for hits, blue for misses
    box_colors = ['red', 'royalblue']
    means = []
    for i, values in enumerate(data):
        box = bp['boxes'][i]
        # the first five vertices of the box outline form the closed box
        box_coords = list(zip(box.get_xdata()[:5], box.get_ydata()[:5]))
        ax0.add_patch(Polygon(box_coords, facecolor=box_colors[i % 2]))
        # Redraw the median line over the filled box.  Draw on ax0
        # explicitly: pyplot's "current axes" is not ax0 when the figure
        # was not created through pyplot.
        med = bp['medians'][i]
        ax0.plot(med.get_xdata()[:2], med.get_ydata()[:2], 'k')
        # Finally, overplot the sample averages, with horizontal alignment
        # in the center of each box
        mean = np.average(values) if len(values) > 0 else np.nan
        means.append(mean)
        ax0.plot([np.average(med.get_xdata())], [mean],
                 color='w', marker='*', markeredgecolor='k')

    # Set the axes ranges and axes labels
    bottom = -1.0
    # fall back to a fixed upper limit when no sample survived filtering
    top = 1.1 * fpkm.max() if fpkm.size > 0 else 1.0
    ax0.set_xlim(0.5, len(data)+0.5)
    ax0.set_ylim(bottom, top)
    xtick_names = plt.setp(ax0, xticklabels=('Hits', 'Misses'))
    plt.setp(xtick_names, rotation=90, fontsize=8)

    # Due to the Y-axis scale being different across samples, it can be
    # hard to compare differences in medians across the samples. Add upper
    # X-axis tick labels with the sample means to aid in comparison
    # (just use two decimal places of precision)
    pos = np.arange(len(data))+1
    upper_labels = [str(np.round(s, 2)) for s in means]
    for tick, _label in zip(range(len(data)), ax0.get_xticklabels()):
        ax0.text(pos[tick], top-(top*0.05), upper_labels[tick],
                 horizontalalignment='center', size='x-small',
                 weight='semibold', color=box_colors[tick % 2])
    return fig

def create_report(result, bm, lib_sizes, config, row_metadata, col_metadata, 
                  ssea_config, sample_set):
    """Create the requested per-result artifacts and return their filenames.

    result       -- result object providing ``t_id`` and ``rand_seed``
    bm           -- open count matrix passed on to the re-run / plotting code
    lib_sizes    -- per-column library sizes (used for expression plots)
    config       -- report settings (``create_*`` flags and ``output_dir``)
    row_metadata -- dict: row index -> metadata dict; must contain
                    'transcript_id', and 'transcript_length' when expression
                    plots are requested
    col_metadata -- dict: column index -> metadata dict with 'library_id'
    ssea_config  -- configuration passed to ssea_rerun()
    sample_set   -- sample set being reported

    Returns a dict mapping artifact keys ('eplot_png', 'eplot_pdf',
    'exprplot_png', 'exprplot_pdf', 'tsv', 'html') to file names relative to
    ``config.output_dir``; only artifacts that were created are present.
    """
    filedict = {}
    # rerun ssea
    ssea_data = ssea_rerun(result.t_id, result.rand_seed, bm, ssea_config, sample_set)
    rowmeta = row_metadata[result.t_id]
    transcript_id = rowmeta['transcript_id']
    # enrichment plot
    if config.create_pdf or config.create_png:
        title = 'Enrichment plot: %s vs. %s' % (transcript_id, sample_set.name)
        fig = plot_enrichment(result, ssea_data,
                              title=title,
                              fig=global_fig)
        if config.create_png:
            eplot_png = '%s.eplot.png' % (transcript_id)
            fig.savefig(os.path.join(config.output_dir, eplot_png))
            filedict['eplot_png'] = eplot_png
        if config.create_pdf:
            eplot_pdf = '%s.eplot.pdf' % (transcript_id)
            fig.savefig(os.path.join(config.output_dir, eplot_pdf))
            filedict['eplot_pdf'] = eplot_pdf
    # expression plot
    if config.create_expr_pdf or config.create_expr_png:
        transcript_length = int(rowmeta['transcript_length'])
        title = 'Expression plot: %s' % (transcript_id)
        fig = plot_expression(result.t_id, transcript_length,
                              lib_sizes, sample_set, bm,
                              title=title, fig=global_fig)
        if config.create_expr_png:
            exprplot_png = '%s.fpkm.png' % (transcript_id)
            fig.savefig(os.path.join(config.output_dir, exprplot_png))
            filedict['exprplot_png'] = exprplot_png
        if config.create_expr_pdf:
            exprplot_pdf = '%s.fpkm.pdf' % (transcript_id)
            fig.savefig(os.path.join(config.output_dir, exprplot_pdf))
            filedict['exprplot_pdf'] = exprplot_pdf
    # details of individual samples (only produced for html reports)
    if config.create_html:
        # show details of hits
        rows = [['index', 'sample', 'rank', 'raw_weights', 'transformed_weights',
                 'running_es', 'core_enrichment']]
        for i,ind in enumerate(ssea_data.hit_indexes):
            # a hit is flagged as core enrichment when it lies on the side
            # of the ES rank indicated by the sign of the enrichment score
            if ssea_data.es < 0:
                is_enriched = int(ind >= ssea_data.es_rank)
            else:
                is_enriched = int(ind <= ssea_data.es_rank)
            colmeta = col_metadata[ssea_data.sample_inds[ind]]
            rows.append([i, colmeta['library_id'], ind+1,
                         ssea_data.raw_weights[ind],
                         ssea_data.weights_hit[ind],
                         ssea_data.es_run[ind],
                         is_enriched])
        # write details to file
        details_tsv = '%s.tsv' % (transcript_id)
        with open(os.path.join(config.output_dir, details_tsv), 'w') as fp:
            for row in rows:
                fp.write('\t'.join(map(str,row)) + '\n')
        filedict['tsv'] = details_tsv
        # render to html
        details_html = '%s.html' % (transcript_id)
        t = env.get_template('details.html')
        with open(os.path.join(config.output_dir, details_html), 'w') as fp:
            fp.write(t.render(result=result,
                              sseadata=ssea_data,
                              details=rows,
                              files=filedict,
                              rowmeta=rowmeta,
                              sample_set=sample_set) + '\n')
        filedict['html'] = details_html
    return filedict

def _producer_process(input_queue, filename, t_ids, num_processes):
    def parse_results(filename):
        with open(filename, 'r') as fp:
            for line in fp:
                result = Result.from_json(line.strip())
                yield result
    for result in parse_results(filename):
        if result.t_id in t_ids:
            input_queue.put(result)
    # tell consumers to stop
    for i in xrange(num_processes):
        input_queue.put(None)
    logging.debug("Producer finished")

def _worker_process(input_queue, output_queue, config, row_metadata, 
                    col_metadata, ssea_config, sample_set): 
    """Consume results from ``input_queue`` and create their reports.

    For every result received, create_report() is called, the returned file
    dict is stored as ``result.files`` and the result is put on
    ``output_queue``.  A ``None`` on the input queue ends the loop.

    A failure while reporting a single result is logged and that result is
    skipped, so the worker keeps draining the input queue.  The matrix is
    closed and the ``None`` "done" marker is put on ``output_queue`` in all
    cases; without the marker the parent would wait on the queue forever.
    An error while opening the matrix still propagates after the marker has
    been sent.
    """
    bm = None
    try:
        # open data matrix
        bm = BigCountMatrix.open(config.matrix_dir)
        logging.debug('Getting total fragments per library')
        lib_sizes = get_library_sizes(bm)
        # process results
        while True:
            result = input_queue.get()
            if result is None:
                break
            try:
                # update result with location of files
                result.files = create_report(result, bm, lib_sizes, config,
                                             row_metadata, col_metadata,
                                             ssea_config, sample_set)
            except Exception:
                logging.exception("Worker could not create report for "
                                  "t_id %s; skipping" % (result.t_id))
                continue
            # send result back
            output_queue.put(result)
    finally:
        if bm is not None:
            bm.close()
        # send done signal
        output_queue.put(None)
    logging.debug("Worker finished")

def create_html_report(input_file, output_file, row_metadata, sample_set):
    """Render the summary 'report.html' template to ``output_file``.

    input_file   -- file with one JSON-encoded result per line
    output_file  -- path of the HTML file to write
    row_metadata -- dict: row index -> metadata dict (with 'transcript_id')
    sample_set   -- sample set providing ``name``, ``desc`` and a length

    Results are parsed lazily and handed to the template as a generator.
    """
    def _result_parser(filename):
        # decorate each parsed result with the sample set description and
        # its row metadata before passing it to the template
        with open(filename, 'r') as fp:
            for line in fp:
                result = Result.from_json(line.strip())
                rowmeta = row_metadata[result.t_id]
                result.sample_set_name = sample_set.name
                result.sample_set_desc = sample_set.desc
                result.sample_set_size = len(sample_set)
                result.name = rowmeta['transcript_id']
                # all metadata except the id, which is exposed as 'name'
                result.params = dict(rowmeta)
                del result.params['transcript_id']
                yield result
    # get full list of parameters in row metadata
    allparams = set()
    for v in row_metadata.values():
        allparams.update(v)
    allparams = sorted(allparams)
    # render templates
    t = env.get_template('report.html')
    with open(output_file, 'w') as fp:
        fp.write(t.render(name=sample_set.name,
                          params=allparams,
                          results=_result_parser(input_file)) + '\n')

def report(config):
    """Generate all report artifacts described by ``config``.

    Steps: create the output directory (and copy the static web files when
    HTML is requested), load row / column metadata, start one producer and
    ``config.num_processes`` worker processes, collect the workers' results
    into 'report.json' (one JSON document per line) and, when HTML is
    requested, render 'report.html'.

    Returns 0.
    """
    # create output dir
    if not os.path.exists(config.output_dir):
        logging.debug("Creating output dir '%s'" % (config.output_dir))
        os.makedirs(config.output_dir)
    # create directory for static web files (CSS, javascript, etc)
    if config.create_html:
        web_dir = os.path.join(config.output_dir, 'web')
        if not os.path.exists(web_dir):
            logging.debug("Installing web files")
            shutil.copytree(SRC_WEB_PATH, web_dir)
    # open expression matrix to read metadata
    bm = BigCountMatrix.open(config.matrix_dir)
    try:
        # read metadata; without a metadata file only the name is known
        if config.row_metadata_file is not None:
            logging.info("Reading row metadata")
            row_metadata = parse_metadata(config.row_metadata_file, bm.rownames)
        else:
            row_metadata = dict((i,{'transcript_id': x}) for i,x in enumerate(bm.rownames))
        if config.col_metadata_file is not None:
            logging.info("Reading column metadata")
            col_metadata = parse_metadata(config.col_metadata_file, bm.colnames)
        else:
            col_metadata = dict((i,{'library_id': x}) for i,x in enumerate(bm.colnames))
    finally:
        # close the matrix even when a metadata file cannot be read
        bm.close()
    # produce detailed reports
    logging.info("Creating detailed reports with %d processes" % (config.num_processes))
    sample_set_json_file = os.path.join(config.input_dir,
                                        Config.SAMPLE_SET_JSON_FILE)
    sample_set = SampleSet.parse_json(sample_set_json_file)[0]
    ssea_config_json_file = os.path.join(config.input_dir,
                                         Config.CONFIG_JSON_FILE)
    ssea_config = Config.parse_json(ssea_config_json_file)
    results_json_file = os.path.join(config.input_dir, Config.RESULTS_JSON_FILE)
    report_json_file = os.path.join(config.output_dir, 'report.json')
    # create multiprocessing queues for passing data
    input_queue = Queue(maxsize=config.num_processes*3)
    output_queue = Queue(maxsize=config.num_processes*3)
    # start a producer process
    logging.debug("Starting producer process and %d workers" % (config.num_processes))
    producer_args = (input_queue, results_json_file, set(row_metadata),
                     config.num_processes)
    producer = Process(target=_producer_process, args=producer_args)
    producer.start()
    # start consumer processes
    worker_args = (input_queue, output_queue, config, row_metadata,
                   col_metadata, ssea_config, sample_set)
    procs = []
    for _ in range(config.num_processes):
        p = Process(target=_worker_process, args=worker_args)
        p.start()
        procs.append(p)
    # get results from consumers; every worker sends one None when done
    num_alive = config.num_processes
    with open(report_json_file, 'w') as fp:
        while num_alive > 0:
            result = output_queue.get()
            if result is None:
                num_alive -= 1
                logging.debug("Main process detected worker finished, %d still alive" % (num_alive))
                continue
            # write result as one JSON document per line
            fp.write(result.to_json() + '\n')
    logging.debug("Joining all processes")
    # wait for producer to finish
    producer.join()
    # wait for consumers to finish
    for p in procs:
        p.join()
    # produce report
    if config.create_html:
        logging.info("Writing HTML report")
        html_file = os.path.join(config.output_dir, 'report.html')
        create_html_report(report_json_file, html_file,
                           row_metadata, sample_set)
    logging.info("Done.")
    return 0

def main(argv=None):
    """Script entry point: parse the command line, log it, run the report.

    ``argv`` is accepted but not used; arguments are taken from the process
    command line by ReportConfig.parse_args().  Returns the value returned
    by report().
    """
    cfg = ReportConfig.parse_args()
    cfg.log()
    status = report(cfg)
    return status

# When executed as a script, run main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
